{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.llms import OpenAI\n",
    "from langchain.prompts import PromptTemplate \n",
    "from langchain.chains.llm import LLMChain\n",
    "from cot import Collection\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Chain to retrieve a chain-of-thought\"\"\"\n",
    "llm = OpenAI(temperature=.0)\n",
    "template = \"\"\"{instruction}\n",
    "\n",
    "Question: {question}\n",
    "Answer_choices: {answer_choices}\n",
    "\n",
    "{cot_trigger}\n",
    "\"\"\"\n",
    "prompt_template = PromptTemplate(input_variables=[\"instruction\",\"question\",\"answer_choices\",\"cot_trigger\"], template=template)\n",
    "cot_chain = LLMChain(llm=llm, prompt=prompt_template,output_key=\"cot\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"This chain retrieves answer extraction\"\"\"\n",
    "extraction_template = \"\"\"{instruction}\n",
    "\n",
    "Question: {question}\n",
    "Answer_choices: {answer_choices}\n",
    "\n",
    "Cot: {cot_trigger}{cot}\n",
    "{answer_extraction}\n",
    "\"\"\"\n",
    "\n",
    "prompt_template = PromptTemplate(input_variables=[\"instruction\",\"question\",\"answer_choices\",\"cot_trigger\",\"cot\",\"answer_extraction\"], template=extraction_template)\n",
    "answer_chain = LLMChain(llm=llm, prompt=prompt_template,output_key=\"answer\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Combine chains\"\"\"\n",
    "\n",
    "from langchain.chains import SequentialChain\n",
    "overall_chain = SequentialChain(chains=[cot_chain, answer_chain],\n",
    "                                input_variables=[\"instruction\",\"question\",\"answer_choices\",\"cot_trigger\",\"answer_extraction\"],\n",
    "                                output_variables=[\"cot\", \"answer\"],\n",
    "                                verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading worldtree...\n"
     ]
    }
   ],
   "source": [
    "input_dict = {\n",
    "    \"instruction\": \"Be faithful and a little hopeful\",\n",
    "    \"cot_trigger\": \"Answer: Let's think step by step.\",\n",
    "    \"answer_extraction_key\": ['kojima-A-D'], \n",
    "}\n",
    "\n",
    "\n",
    "worldtree = Collection([\"worldtree\"], verbose=False)\n",
    "worldtree_1 = worldtree.select(split=\"train\", number_samples=1, random_samples=True, seed=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating worldtree...\n"
     ]
    },
    {
     "ename": "TypeError",
     "evalue": "_self_generate_extract() got multiple values for argument 'instruction'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[1;32m/Users/robertpraas/Desktop/ThoughtSource/notebooks/data_map_error.ipynb Cell 6\u001b[0m in \u001b[0;36m<cell line: 2>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      <a href='vscode-notebook-cell:/Users/robertpraas/Desktop/ThoughtSource/notebooks/data_map_error.ipynb#W3sZmlsZQ%3D%3D?line=0'>1</a>\u001b[0m \u001b[39m#generate_extract_flexibly in dataloader.py\u001b[39;00m\n\u001b[0;32m----> <a href='vscode-notebook-cell:/Users/robertpraas/Desktop/ThoughtSource/notebooks/data_map_error.ipynb#W3sZmlsZQ%3D%3D?line=1'>2</a>\u001b[0m test \u001b[39m=\u001b[39m worldtree_1\u001b[39m.\u001b[39;49mgenerate_extract_flexibly(chain\u001b[39m=\u001b[39;49moverall_chain,input_dict\u001b[39m=\u001b[39;49minput_dict)\n",
      "File \u001b[0;32m~/Desktop/ThoughtSource/libs/cot/cot/dataloader.py:329\u001b[0m, in \u001b[0;36mCollection.generate_extract_flexibly\u001b[0;34m(self, chain, input_dict, name, split)\u001b[0m\n\u001b[1;32m    327\u001b[0m     \u001b[39mfor\u001b[39;00m name \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_cache:\n\u001b[1;32m    328\u001b[0m         \u001b[39mprint\u001b[39m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mGenerating \u001b[39m\u001b[39m{\u001b[39;00mname\u001b[39m}\u001b[39;00m\u001b[39m...\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m--> 329\u001b[0m         \u001b[39mself\u001b[39m[name] \u001b[39m=\u001b[39m self_generate_extract(\u001b[39mself\u001b[39;49m[name], chain, input_dict)\n\u001b[1;32m    330\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m    331\u001b[0m     \u001b[39mif\u001b[39;00m split \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n",
      "File \u001b[0;32m~/Desktop/ThoughtSource/libs/cot/cot/new_generate.py:38\u001b[0m, in \u001b[0;36mself_generate_extract\u001b[0;34m(data, chain, input_dict)\u001b[0m\n\u001b[1;32m     33\u001b[0m     \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mNot recognized data\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m     36\u001b[0m input_dict[\u001b[39m'\u001b[39m\u001b[39mchain\u001b[39m\u001b[39m'\u001b[39m] \u001b[39m=\u001b[39m chain\n\u001b[0;32m---> 38\u001b[0m \u001b[39mreturn\u001b[39;00m data\u001b[39m.\u001b[39;49mmap(\n\u001b[1;32m     39\u001b[0m     _self_generate_extract,\n\u001b[1;32m     40\u001b[0m     with_indices\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m,\n\u001b[1;32m     41\u001b[0m     fn_kwargs\u001b[39m=\u001b[39;49minput_dict,\n\u001b[1;32m     42\u001b[0m     features\u001b[39m=\u001b[39;49mfeatures,\n\u001b[1;32m     43\u001b[0m     load_from_cache_file\u001b[39m=\u001b[39;49m\u001b[39mFalse\u001b[39;49;00m,\n\u001b[1;32m     44\u001b[0m )\n\u001b[1;32m     47\u001b[0m new_dataset \u001b[39m=\u001b[39m []\n\u001b[1;32m     48\u001b[0m \u001b[39mfor\u001b[39;00m example \u001b[39min\u001b[39;00m data[\u001b[39m'\u001b[39m\u001b[39mtrain\u001b[39m\u001b[39m'\u001b[39m]: \u001b[39m#hardcoded\u001b[39;00m\n",
      "File \u001b[0;32m/opt/anaconda3/lib/python3.9/site-packages/datasets/dataset_dict.py:438\u001b[0m, in \u001b[0;36mDatasetDict.map\u001b[0;34m(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_names, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, desc)\u001b[0m\n\u001b[1;32m    435\u001b[0m \u001b[39mif\u001b[39;00m cache_file_names \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m    436\u001b[0m     cache_file_names \u001b[39m=\u001b[39m {k: \u001b[39mNone\u001b[39;00m \u001b[39mfor\u001b[39;00m k \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m}\n\u001b[1;32m    437\u001b[0m \u001b[39mreturn\u001b[39;00m DatasetDict(\n\u001b[0;32m--> 438\u001b[0m     {\n\u001b[1;32m    439\u001b[0m         k: dataset\u001b[39m.\u001b[39mmap(\n\u001b[1;32m    440\u001b[0m             function\u001b[39m=\u001b[39mfunction,\n\u001b[1;32m    441\u001b[0m             with_indices\u001b[39m=\u001b[39mwith_indices,\n\u001b[1;32m    442\u001b[0m             with_rank\u001b[39m=\u001b[39mwith_rank,\n\u001b[1;32m    443\u001b[0m             input_columns\u001b[39m=\u001b[39minput_columns,\n\u001b[1;32m    444\u001b[0m             batched\u001b[39m=\u001b[39mbatched,\n\u001b[1;32m    445\u001b[0m             batch_size\u001b[39m=\u001b[39mbatch_size,\n\u001b[1;32m    446\u001b[0m             drop_last_batch\u001b[39m=\u001b[39mdrop_last_batch,\n\u001b[1;32m    447\u001b[0m             remove_columns\u001b[39m=\u001b[39mremove_columns,\n\u001b[1;32m    448\u001b[0m             keep_in_memory\u001b[39m=\u001b[39mkeep_in_memory,\n\u001b[1;32m    449\u001b[0m             load_from_cache_file\u001b[39m=\u001b[39mload_from_cache_file,\n\u001b[1;32m    450\u001b[0m             cache_file_name\u001b[39m=\u001b[39mcache_file_names[k],\n\u001b[1;32m    451\u001b[0m             writer_batch_size\u001b[39m=\u001b[39mwriter_batch_size,\n\u001b[1;32m    452\u001b[0m             features\u001b[39m=\u001b[39mfeatures,\n\u001b[1;32m    453\u001b[0m             disable_nullable\u001b[39m=\u001b[39mdisable_nullable,\n\u001b[1;32m    454\u001b[0m             fn_kwargs\u001b[39m=\u001b[39mfn_kwargs,\n\u001b[1;32m    455\u001b[0m             num_proc\u001b[39m=\u001b[39mnum_proc,\n\u001b[1;32m    456\u001b[0m             desc\u001b[39m=\u001b[39mdesc,\n\u001b[1;32m    457\u001b[0m         )\n\u001b[1;32m    458\u001b[0m         \u001b[39mfor\u001b[39;00m k, dataset \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mitems()\n\u001b[1;32m    459\u001b[0m     }\n\u001b[1;32m    460\u001b[0m )\n",
      "File \u001b[0;32m/opt/anaconda3/lib/python3.9/site-packages/datasets/dataset_dict.py:439\u001b[0m, in \u001b[0;36m<dictcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m    435\u001b[0m \u001b[39mif\u001b[39;00m cache_file_names \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m    436\u001b[0m     cache_file_names \u001b[39m=\u001b[39m {k: \u001b[39mNone\u001b[39;00m \u001b[39mfor\u001b[39;00m k \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m}\n\u001b[1;32m    437\u001b[0m \u001b[39mreturn\u001b[39;00m DatasetDict(\n\u001b[1;32m    438\u001b[0m     {\n\u001b[0;32m--> 439\u001b[0m         k: dataset\u001b[39m.\u001b[39;49mmap(\n\u001b[1;32m    440\u001b[0m             function\u001b[39m=\u001b[39;49mfunction,\n\u001b[1;32m    441\u001b[0m             with_indices\u001b[39m=\u001b[39;49mwith_indices,\n\u001b[1;32m    442\u001b[0m             with_rank\u001b[39m=\u001b[39;49mwith_rank,\n\u001b[1;32m    443\u001b[0m             input_columns\u001b[39m=\u001b[39;49minput_columns,\n\u001b[1;32m    444\u001b[0m             batched\u001b[39m=\u001b[39;49mbatched,\n\u001b[1;32m    445\u001b[0m             batch_size\u001b[39m=\u001b[39;49mbatch_size,\n\u001b[1;32m    446\u001b[0m             drop_last_batch\u001b[39m=\u001b[39;49mdrop_last_batch,\n\u001b[1;32m    447\u001b[0m             remove_columns\u001b[39m=\u001b[39;49mremove_columns,\n\u001b[1;32m    448\u001b[0m             keep_in_memory\u001b[39m=\u001b[39;49mkeep_in_memory,\n\u001b[1;32m    449\u001b[0m             load_from_cache_file\u001b[39m=\u001b[39;49mload_from_cache_file,\n\u001b[1;32m    450\u001b[0m             cache_file_name\u001b[39m=\u001b[39;49mcache_file_names[k],\n\u001b[1;32m    451\u001b[0m             writer_batch_size\u001b[39m=\u001b[39;49mwriter_batch_size,\n\u001b[1;32m    452\u001b[0m             features\u001b[39m=\u001b[39;49mfeatures,\n\u001b[1;32m    453\u001b[0m             disable_nullable\u001b[39m=\u001b[39;49mdisable_nullable,\n\u001b[1;32m    454\u001b[0m             fn_kwargs\u001b[39m=\u001b[39;49mfn_kwargs,\n\u001b[1;32m    455\u001b[0m             num_proc\u001b[39m=\u001b[39;49mnum_proc,\n\u001b[1;32m    456\u001b[0m             desc\u001b[39m=\u001b[39;49mdesc,\n\u001b[1;32m    457\u001b[0m         )\n\u001b[1;32m    458\u001b[0m         \u001b[39mfor\u001b[39;00m k, dataset \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mitems()\n\u001b[1;32m    459\u001b[0m     }\n\u001b[1;32m    460\u001b[0m )\n",
      "File \u001b[0;32m/opt/anaconda3/lib/python3.9/site-packages/datasets/arrow_dataset.py:1955\u001b[0m, in \u001b[0;36mDataset.map\u001b[0;34m(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc)\u001b[0m\n\u001b[1;32m   1952\u001b[0m disable_tqdm \u001b[39m=\u001b[39m \u001b[39mnot\u001b[39;00m logging\u001b[39m.\u001b[39mis_progress_bar_enabled()\n\u001b[1;32m   1954\u001b[0m \u001b[39mif\u001b[39;00m num_proc \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mor\u001b[39;00m num_proc \u001b[39m==\u001b[39m \u001b[39m1\u001b[39m:\n\u001b[0;32m-> 1955\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_map_single(\n\u001b[1;32m   1956\u001b[0m         function\u001b[39m=\u001b[39;49mfunction,\n\u001b[1;32m   1957\u001b[0m         with_indices\u001b[39m=\u001b[39;49mwith_indices,\n\u001b[1;32m   1958\u001b[0m         with_rank\u001b[39m=\u001b[39;49mwith_rank,\n\u001b[1;32m   1959\u001b[0m         input_columns\u001b[39m=\u001b[39;49minput_columns,\n\u001b[1;32m   1960\u001b[0m         batched\u001b[39m=\u001b[39;49mbatched,\n\u001b[1;32m   1961\u001b[0m         batch_size\u001b[39m=\u001b[39;49mbatch_size,\n\u001b[1;32m   1962\u001b[0m         drop_last_batch\u001b[39m=\u001b[39;49mdrop_last_batch,\n\u001b[1;32m   1963\u001b[0m         remove_columns\u001b[39m=\u001b[39;49mremove_columns,\n\u001b[1;32m   1964\u001b[0m         keep_in_memory\u001b[39m=\u001b[39;49mkeep_in_memory,\n\u001b[1;32m   1965\u001b[0m         load_from_cache_file\u001b[39m=\u001b[39;49mload_from_cache_file,\n\u001b[1;32m   1966\u001b[0m         cache_file_name\u001b[39m=\u001b[39;49mcache_file_name,\n\u001b[1;32m   1967\u001b[0m         writer_batch_size\u001b[39m=\u001b[39;49mwriter_batch_size,\n\u001b[1;32m   1968\u001b[0m         features\u001b[39m=\u001b[39;49mfeatures,\n\u001b[1;32m   1969\u001b[0m         disable_nullable\u001b[39m=\u001b[39;49mdisable_nullable,\n\u001b[1;32m   1970\u001b[0m         fn_kwargs\u001b[39m=\u001b[39;49mfn_kwargs,\n\u001b[1;32m   1971\u001b[0m         new_fingerprint\u001b[39m=\u001b[39;49mnew_fingerprint,\n\u001b[1;32m   1972\u001b[0m         disable_tqdm\u001b[39m=\u001b[39;49mdisable_tqdm,\n\u001b[1;32m   1973\u001b[0m         desc\u001b[39m=\u001b[39;49mdesc,\n\u001b[1;32m   1974\u001b[0m     )\n\u001b[1;32m   1975\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m   1977\u001b[0m     \u001b[39mdef\u001b[39;00m \u001b[39mformat_cache_file_name\u001b[39m(cache_file_name, rank):\n",
      "File \u001b[0;32m/opt/anaconda3/lib/python3.9/site-packages/datasets/arrow_dataset.py:520\u001b[0m, in \u001b[0;36mtransmit_tasks.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    518\u001b[0m     \u001b[39mself\u001b[39m: \u001b[39m\"\u001b[39m\u001b[39mDataset\u001b[39m\u001b[39m\"\u001b[39m \u001b[39m=\u001b[39m kwargs\u001b[39m.\u001b[39mpop(\u001b[39m\"\u001b[39m\u001b[39mself\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m    519\u001b[0m \u001b[39m# apply actual function\u001b[39;00m\n\u001b[0;32m--> 520\u001b[0m out: Union[\u001b[39m\"\u001b[39m\u001b[39mDataset\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39mDatasetDict\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m func(\u001b[39mself\u001b[39;49m, \u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m    521\u001b[0m datasets: List[\u001b[39m\"\u001b[39m\u001b[39mDataset\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m \u001b[39mlist\u001b[39m(out\u001b[39m.\u001b[39mvalues()) \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(out, \u001b[39mdict\u001b[39m) \u001b[39melse\u001b[39;00m [out]\n\u001b[1;32m    522\u001b[0m \u001b[39mfor\u001b[39;00m dataset \u001b[39min\u001b[39;00m datasets:\n\u001b[1;32m    523\u001b[0m     \u001b[39m# Remove task templates if a column mapping of the template is no longer valid\u001b[39;00m\n",
      "File \u001b[0;32m/opt/anaconda3/lib/python3.9/site-packages/datasets/arrow_dataset.py:487\u001b[0m, in \u001b[0;36mtransmit_format.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    480\u001b[0m self_format \u001b[39m=\u001b[39m {\n\u001b[1;32m    481\u001b[0m     \u001b[39m\"\u001b[39m\u001b[39mtype\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_format_type,\n\u001b[1;32m    482\u001b[0m     \u001b[39m\"\u001b[39m\u001b[39mformat_kwargs\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_format_kwargs,\n\u001b[1;32m    483\u001b[0m     \u001b[39m\"\u001b[39m\u001b[39mcolumns\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_format_columns,\n\u001b[1;32m    484\u001b[0m     \u001b[39m\"\u001b[39m\u001b[39moutput_all_columns\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_output_all_columns,\n\u001b[1;32m    485\u001b[0m }\n\u001b[1;32m    486\u001b[0m \u001b[39m# apply actual function\u001b[39;00m\n\u001b[0;32m--> 487\u001b[0m out: Union[\u001b[39m\"\u001b[39m\u001b[39mDataset\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39mDatasetDict\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m func(\u001b[39mself\u001b[39;49m, \u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m    488\u001b[0m datasets: List[\u001b[39m\"\u001b[39m\u001b[39mDataset\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m \u001b[39mlist\u001b[39m(out\u001b[39m.\u001b[39mvalues()) \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(out, \u001b[39mdict\u001b[39m) \u001b[39melse\u001b[39;00m [out]\n\u001b[1;32m    489\u001b[0m \u001b[39m# re-apply format to the output\u001b[39;00m\n",
      "File \u001b[0;32m/opt/anaconda3/lib/python3.9/site-packages/datasets/fingerprint.py:458\u001b[0m, in \u001b[0;36mfingerprint_transform.<locals>._fingerprint.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    452\u001b[0m             kwargs[fingerprint_name] \u001b[39m=\u001b[39m update_fingerprint(\n\u001b[1;32m    453\u001b[0m                 \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_fingerprint, transform, kwargs_for_fingerprint\n\u001b[1;32m    454\u001b[0m             )\n\u001b[1;32m    456\u001b[0m \u001b[39m# Call actual function\u001b[39;00m\n\u001b[0;32m--> 458\u001b[0m out \u001b[39m=\u001b[39m func(\u001b[39mself\u001b[39;49m, \u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m    460\u001b[0m \u001b[39m# Update fingerprint of in-place transforms + update in-place history of transforms\u001b[39;00m\n\u001b[1;32m    462\u001b[0m \u001b[39mif\u001b[39;00m inplace:  \u001b[39m# update after calling func so that the fingerprint doesn't change if the function fails\u001b[39;00m\n",
      "File \u001b[0;32m/opt/anaconda3/lib/python3.9/site-packages/datasets/arrow_dataset.py:2320\u001b[0m, in \u001b[0;36mDataset._map_single\u001b[0;34m(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, new_fingerprint, rank, offset, disable_tqdm, desc, cache_only)\u001b[0m\n\u001b[1;32m   2318\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m batched:\n\u001b[1;32m   2319\u001b[0m     \u001b[39mfor\u001b[39;00m i, example \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(pbar):\n\u001b[0;32m-> 2320\u001b[0m         example \u001b[39m=\u001b[39m apply_function_on_filtered_inputs(example, i, offset\u001b[39m=\u001b[39;49moffset)\n\u001b[1;32m   2321\u001b[0m         \u001b[39mif\u001b[39;00m update_data:\n\u001b[1;32m   2322\u001b[0m             \u001b[39mif\u001b[39;00m i \u001b[39m==\u001b[39m \u001b[39m0\u001b[39m:\n",
      "File \u001b[0;32m/opt/anaconda3/lib/python3.9/site-packages/datasets/arrow_dataset.py:2220\u001b[0m, in \u001b[0;36mDataset._map_single.<locals>.apply_function_on_filtered_inputs\u001b[0;34m(inputs, indices, check_same_num_examples, offset)\u001b[0m\n\u001b[1;32m   2218\u001b[0m \u001b[39mif\u001b[39;00m with_rank:\n\u001b[1;32m   2219\u001b[0m     additional_args \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m (rank,)\n\u001b[0;32m-> 2220\u001b[0m processed_inputs \u001b[39m=\u001b[39m function(\u001b[39m*\u001b[39;49mfn_args, \u001b[39m*\u001b[39;49madditional_args, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mfn_kwargs)\n\u001b[1;32m   2221\u001b[0m \u001b[39mif\u001b[39;00m update_data \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m   2222\u001b[0m     \u001b[39m# Check if the function returns updated examples\u001b[39;00m\n\u001b[1;32m   2223\u001b[0m     update_data \u001b[39m=\u001b[39m \u001b[39misinstance\u001b[39m(processed_inputs, (Mapping, pa\u001b[39m.\u001b[39mTable))\n",
      "File \u001b[0;32m/opt/anaconda3/lib/python3.9/site-packages/datasets/arrow_dataset.py:1915\u001b[0m, in \u001b[0;36mDataset.map.<locals>.decorate.<locals>.decorated\u001b[0;34m(item, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1911\u001b[0m decorated_item \u001b[39m=\u001b[39m (\n\u001b[1;32m   1912\u001b[0m     Example(item, features\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mfeatures) \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m batched \u001b[39melse\u001b[39;00m Batch(item, features\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mfeatures)\n\u001b[1;32m   1913\u001b[0m )\n\u001b[1;32m   1914\u001b[0m \u001b[39m# Use the LazyDict internally, while mapping the function\u001b[39;00m\n\u001b[0;32m-> 1915\u001b[0m result \u001b[39m=\u001b[39m f(decorated_item, \u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   1916\u001b[0m \u001b[39m# Return a standard dict\u001b[39;00m\n\u001b[1;32m   1917\u001b[0m \u001b[39mreturn\u001b[39;00m result\u001b[39m.\u001b[39mdata \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(result, LazyDict) \u001b[39melse\u001b[39;00m result\n",
      "\u001b[0;31mTypeError\u001b[0m: _self_generate_extract() got multiple values for argument 'instruction'"
     ]
    }
   ],
   "source": [
    "\"\"\"generate_extract_flexibly in dataloader.py calls self_generate_extract in new_generation.py\n",
    "The problem is in line 36 of new_generation.py the SequentialChain which reuses parts \n",
    "of the templates for the chains is part of the input dict (kwargs) - data.map\n",
    "does not seem to be able to handle a nested dict.\n",
    "\n",
    "The problem is that fn_kwargs takes inputs (fn_kwargs={\"tokenizer\":tokenizer}),\n",
    " we don't want that since we want to build the chain before calling data.map and not after\n",
    "\"\"\"\n",
    "test = worldtree_1.generate_extract_flexibly(chain=overall_chain,input_dict=input_dict)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
